package it.uniroma2.dtk.dt.route;

import it.uniroma2.dtk.common.Spectrum;
import it.uniroma2.dtk.dt.DT;
import it.uniroma2.util.math.ArrayMath;
import it.uniroma2.util.math.MatrixUtils;
import it.uniroma2.util.tree.Tree;
import it.uniroma2.util.vector.RandomVectorGenerator;
import it.uniroma2.util.vector.VectorComposer;
import it.uniroma2.util.vector.VectorProvider;

import java.util.HashMap;

/**
 * Implementation of DT interface for the Route Tree Kernel presented in:
 * 		Aiolli, F.; Da San Martino, G. & Sperduti, A. 
 * 		Route kernels for trees 
 * 		Proceedings of the 26th Annual International Conference on Machine Learning, ACM, 2009, 17-24
 * This class abstracts from the actual ideal composition function implementation.
 * 
 * @author Lorenzo Dell'Arciprete
 */
public abstract class RouteAbstractDT implements DT {

	/** Source of the base vectors for labels and edges. RandomVectorGenerator is the default implementation. */
	protected VectorProvider vectorProvider;
	
	/** Dimension of the vector space, as passed to the constructor. */
	protected int vectorSize;
	/** If true, node vectors come from getProductionVector, otherwise from getLabelVector. */
	protected boolean useProductions;
	/**
	 * Square root of the lambda decay parameter.
	 * NOTE: despite the "Sq" suffix this is sqrt(lambda), not lambda squared; see getLambda/setLambda.
	 */
	protected double lambdaSq = 1;
	
	/** Cache of the edge vectors, keyed by the positional index of the child. */
	protected HashMap<Integer, double[]> edgeVectors = new HashMap<Integer, double[]>();
	/**
	 * Prefix of the keys used to request edge vectors from the vector provider (control character with code 6).
	 * Presumably chosen so that edge keys cannot clash with node labels - TODO confirm.
	 */
	protected String edgeMarker = Character.toString((char)6);
	
	/**
	 * @param randomOffset - random seed for vector generation and composition
	 * @param vectorsSize - the dimension of the desired vector space
	 * @param useProductions - if true, the vector of a node is composed from its label and the labels of its children
	 * @param lambda - the value of the lambda decaying factor; must be a non-negative number
	 * @throws IllegalArgumentException if lambda is negative or NaN
	 * @throws Exception if the underlying RandomVectorGenerator cannot be created
	 */
	public RouteAbstractDT(int randomOffset, int vectorsSize, boolean useProductions, double lambda) throws Exception {
		//Reject an unusable lambda before building anything: sqrt of a negative number is NaN,
		//and that NaN would silently spread to every computed vector.
		double lambdaRoot = sqrtOfLambda(lambda);
		vectorProvider = new RandomVectorGenerator(vectorsSize, randomOffset);
		this.vectorSize = vectorsSize;
		this.useProductions = useProductions;
		lambdaSq = lambdaRoot;
	}
	
	/**
	 * Same as the full constructor, with lambda set to 1 (no decay).
	 */
	public RouteAbstractDT(int randomOffset, int vectorsSize, boolean useProductions) throws Exception {
		this(randomOffset, vectorsSize, useProductions, 1);
	}
	
	/**
	 * Same as the full constructor, with useProductions set to false.
	 */
	public RouteAbstractDT(int randomOffset, int vectorsSize, double lambda) throws Exception {
		this(randomOffset, vectorsSize, false, lambda);
	}
	
	/**
	 * The vector composition function
	 */
	public abstract double[] op(double[] v1, double[] v2) throws Exception;
	
	public int getVectorSize() {return vectorSize;}
	
	public VectorProvider getVectorProvider() {return vectorProvider;}
	
	/**
	 * Computes the vector for the input tree as the sum of s(n) over all of its nodes.
	 * This method is best-effort: if the computation fails, the stack trace is printed and
	 * the vector accumulated up to that point (possibly all zeros) is returned.
	 * @param x - the input tree
	 * @return the sum of s(n) for the nodes processed
	 */
	public double[] dt(Tree x) {
		//The Spectrum object will collect the values of s(n) during the course of its computation
		Spectrum result = new Spectrum();
		result.setVector(MatrixUtils.uniformVector(vectorSize, 0));
		try {
			sRecursive(x, result);
		}
		catch(Exception e) {
			//Kept as is for backward compatibility: callers receive the partial sum instead of an exception
			e.printStackTrace();
		}
		return result.getVector();
	}
	
	/**
	 * Recursive computation of function s(n) for the root of the input tree.
	 * To save time and space, the computed s(n) is directly added to the final sum.
	 * @param node - the input tree or, equivalently, its root node
	 * @param sum - the object collecting the sum of s(n) for each node in the tree
	 * @return s(node)
	 * @throws Exception
	 */
	protected double[] sRecursive(Tree node, Spectrum sum) throws Exception {
		double[] result = MatrixUtils.uniformVector(vectorSize, 0);
		//Recursive computation of s(n).
		if (!node.isTerminal()) {
			int childCount = node.getChildren().size();
			for (int i=0; i<childCount; i++) {
				Tree child = node.getChildren().get(i);
				//Each child contributes its own s(child), composed with the vector of the edge leading to it
				result = VectorComposer.sum(result, op(getEdgeVector(i), sRecursive(child, sum)));
			}
			//The decay is applied once per level, so it compounds along deeper routes
			result = ArrayMath.scalardot(lambdaSq, result);
		}
		if (useProductions)
			result = VectorComposer.sum(getProductionVector(node), result);
		else
			result = VectorComposer.sum(getLabelVector(node), result);
		//Adding s(node) to the final sum
		sum.setVector(VectorComposer.sum(sum.getVector(), result));
		return result;
	}
	
	/**
	 * Not supported by this implementation: an error message is printed and null is returned.
	 */
	public double[] dtf(Tree x) {
		System.err.println("Not available (positional information is missing)");
		return null;
	}
	
	/**
	 * Returns the vector the provider associates with the root label of the given node.
	 */
	public double[] getLabelVector(Tree node) throws Exception {
		return vectorProvider.getVector(node.getRootLabel());
	}
	
	/**
	 * Returns the vector for the production rooted in the given node: the vector of the node label
	 * composed, through op and from left to right, with the label vector of each child.
	 */
	public double[] getProductionVector(Tree node) throws Exception {
		double[] result = vectorProvider.getVector(node.getRootLabel());
		for (Tree child : node.getChildren())
			result = op(result, vectorProvider.getVector(child.getRootLabel()));
		return result;
	}
	
	/**
	 * Returns the standard vector representing an edge with the given positional index.
	 * The values are computed once and then stored, for efficiency reasons.
	 */
	public double[] getEdgeVector(int index) throws Exception {
		//The cache is a protected field, so it may have been set to null from outside this method
		if (edgeVectors == null)
			edgeVectors = new HashMap<Integer, double[]>();
		double[] vector = edgeVectors.get(index);
		if (vector == null) {
			vector = vectorProvider.getVector(edgeMarker+index);
			edgeVectors.put(index, vector);
		}
		return vector;
	}
	
	/**
	 * Returns the lambda decaying factor, rebuilt by squaring the stored square root.
	 */
	public double getLambda() {
		return lambdaSq*lambdaSq;
	}
	
	/**
	 * @param lambda - the new value of the lambda decaying factor; must be a non-negative number
	 * @throws IllegalArgumentException if lambda is negative or NaN
	 */
	public void setLambda(double lambda) {
		lambdaSq = sqrtOfLambda(lambda);
	}

	public boolean isUseProductions() {
		return useProductions;
	}

	public void setUseProductions(boolean useProductions) {
		this.useProductions = useProductions;
	}

	/**
	 * Validates lambda and returns its square root, which is the value actually stored.
	 */
	private static double sqrtOfLambda(double lambda) {
		if (Double.isNaN(lambda) || lambda < 0)
			throw new IllegalArgumentException("lambda must be a non-negative number, got " + lambda);
		return Math.sqrt(lambda);
	}

}
